from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm
import json
import torch
import random
import numpy as np
import pandas as pd

# Reproducibility: seed every RNG module the script imports with one value.
# `random` drives the prompt sample below; torch drives sampled decoding
# (GPU kernels may still be non-deterministic); numpy is seeded too although
# no numpy draw is visible in this script.
seed = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
del _seed_rng

# Load the false-presupposition prompt set; only the 'prompt' column is used.
df = pd.read_csv('prompts_falsepresupposition.csv')
prompts = df['prompt']
# Candidate pool for random sampling: row positions 101 .. len(prompts)-1.
# NOTE(review): range(100) below covers rows 0-99 and this pool starts at 101,
# so row 100 can never be selected -- looks like an off-by-one
# (range(100, len(prompts)) would be the exact complement). Left unchanged
# because changing the pool changes which rows random.sample draws for this
# seed, and therefore any outputs already generated -- confirm before fixing.
valid_indices = list(range(101, len(prompts)))
# Always evaluate the first 100 rows, plus 150 rows drawn without replacement
# from the pool. This is the first draw from `random` after it is seeded, so
# the selection is reproducible only while that ordering is preserved.
# random.sample raises ValueError if the pool holds fewer than 150 entries
# (i.e. the CSV has fewer than 251 rows).
selected_indices = list(range(100)) + random.sample(valid_indices, 150) 

# Accumulates one {'Prompt', 'Responses'} record per selected prompt.
results = []

# 8-bit (LLM.int8) weight quantization. The threshold value is the library
# default, and None means no modules are excluded from quantization.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_skip_modules=None,
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-72B-Instruct")
# device_map="auto" lets transformers place the layers across the available devices.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-72B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
)

# Stop generation at the tokenizer's EOS token. The list does not depend on the
# prompt, so it is built once rather than on every iteration.
terminators = [
    tokenizer.eos_token_id,
]

# Sample 5 completions for every selected prompt. tqdm (imported at the top of
# the file) reports progress for this long-running loop.
for ind in tqdm(selected_indices):
    # prompts[ind] is a label lookup; it matches the row position because
    # read_csv without index_col yields a default RangeIndex.
    # NOTE(review): ". " is appended to every prompt verbatim, even ones that
    # already end in punctuation -- confirm this is intended.
    messages = [{"role": "user", "content": prompts[ind] + ". "}]

    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # A single unpadded sequence: every position is real input. Passing the
    # mask explicitly avoids generate() trying to infer it -- it cannot,
    # because pad_token_id is set to the EOS id below -- and the warning that
    # produces. The all-ones mask is the same as generate()'s fallback.
    attention_mask = torch.ones_like(input_ids)

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=500,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.9,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=5,
    )
    # Each returned row is prompt + continuation; drop the prompt tokens.
    prompt_len = input_ids.shape[-1]
    generations = [
        tokenizer.decode(seq[prompt_len:], skip_special_tokens=True).strip()
        for seq in outputs
    ]
    results.append({'Prompt': prompts[ind], 'Responses': generations})

# Persist every prompt with its sampled responses as pretty-printed JSON.
with open('fp_qwen72b.json', 'w') as out_file:
    out_file.write(json.dumps(results, indent=4))
